
Bootstrap Aggregation #229

Open: jk1015 wants to merge 4 commits into master

Conversation

@jk1015 commented Jul 7, 2022

Following on from the discussion in issue #199, here is an implementation of Bootstrap Aggregation following the approach laid out there. There is also an example showing how the provided methods can be used together with the existing linfa-trees package to implement a Random Forest (a sketch follows the summary below).

Outside of the linfa-ensemble package, the only required change was to factor the trait FromTargetArray to create an additional trait FromTargetArrayOwned. This allows owned data to be used without needing to specify a lifetime, which was impossible under the previous FromTargetArray trait.

- Added an example using bootstrap aggregation to carry out Random Forest classification.
- Factored the linfa trait FromTargetArray to create an additional trait FromTargetArrayOwned.
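
A minimal sketch of how these pieces combine into a Random Forest, assuming a builder-style EnsembleLearnerParams with new, ensemble_size, and bootstrap_proportion methods (names inferred from the fields discussed in the review below; the PR's bundled example is authoritative):

use linfa::prelude::*;
use linfa_ensemble::EnsembleLearnerParams;
use linfa_trees::DecisionTree;

fn main() {
    // 80/20 train/test split of the iris dataset
    let (train, test) = linfa_datasets::iris().split_with_ratio(0.8);

    // 100 trees, each fit on a bootstrap sample containing 70% as many
    // rows as the full training set
    let model = EnsembleLearnerParams::new(DecisionTree::params())
        .ensemble_size(100)
        .bootstrap_proportion(0.7)
        .fit(&train)
        .unwrap();

    let pred = model.predict(&test);
    println!("accuracy: {:.3}", pred.confusion_matrix(&test).unwrap().accuracy());
}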
@codecov-commenter commented Jul 16, 2022


Codecov Report

Attention: Patch coverage is 0% with 67 lines in your changes missing coverage. Please review.

Project coverage is 55.11%. Comparing base (d4bd9c9) to head (46722ec).
Report is 52 commits behind head on master.

Files with missing lines                     Patch %   Lines
algorithms/linfa-ensemble/src/ensemble.rs    0.00%     67 Missing ⚠️


Additional details and impacted files
@@            Coverage Diff             @@
##           master     #229      +/-   ##
==========================================
- Coverage   55.44%   55.11%   -0.34%     
==========================================
  Files          95       97       +2     
  Lines        8774     9014     +240     
==========================================
+ Hits         4865     4968     +103     
- Misses       3909     4046     +137     


pub struct EnsembleLearnerParams<P> {
    pub ensemble_size: usize,
Collaborator
Why do we need a separate field for ensemble_size? Isn't this value implied by bootstrap_proportion?

Author
ensemble_size gives the number of models in the ensemble, while bootstrap_proportion gives the proportion of the total number of training samples given to each model for training. These should be distinct parameters.

Collaborator
Shouldn't bootstrap_proportion be the same as 1/ensemble_size?

Author
Not necessarily; each model in the ensemble just needs its own random set of samples of training data from the complete training data set. There are no constraints on the size of this set other than it being non-empty, so we let the user tune this size as a hyperparameter.

Collaborator
OK, so bootstrap_samples just grabs random sets of samples from the input and yields them infinitely. I thought it divided the input into random subsamples. This makes sense now.
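
For reference, a minimal sketch of that sampling behaviour (bootstrap_indices is an illustrative stand-in, not the PR's actual helper):

use rand::Rng;

// Yields an endless stream of index sets, each drawn uniformly with
// replacement from 0..n_samples; callers take() as many as they need.
fn bootstrap_indices(
    n_samples: usize,
    bootstrap_proportion: f64,
) -> impl Iterator<Item = Vec<usize>> {
    let subset_size = ((n_samples as f64) * bootstrap_proportion).ceil() as usize;
    std::iter::repeat_with(move || {
        let mut rng = rand::thread_rng();
        (0..subset_size).map(|_| rng.gen_range(0..n_samples)).collect()
    })
}

fn main() {
    // ensemble_size and bootstrap_proportion vary independently:
    // here, 5 models each see 70% as many rows as the full training set.
    for indices in bootstrap_indices(100, 0.7).take(5) {
        println!("model trained on {} resampled rows", indices.len());
    }
}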

Collaborator
Can you add this behaviour to the docs, along with a general description of EnsembleLearner? We should also have top-level docs in src/lib.rs, as with the other crates.

Comment on lines 74 to 80
let aggregated_predictions = self.aggregate_predictions(&mut predictions);

for (target, output) in y_array
    .axis_iter_mut(Axis(0))
    .zip(aggregated_predictions.into_iter())
{
    for (t, o) in target.into_iter().zip(output[0].0.iter()) {
        *t = *o;
    }
}
Collaborator
Replace with this (assuming use std::collections::HashMap; and use ndarray::Zip; are in scope):

// prediction_maps has the same shape as y_array, but the elements are
// maps counting votes for each predicted value
let mut prediction_maps = y_array.map(|_| HashMap::new());

for prediction in predictions {
    let p_arr = prediction.as_targets();
    assert_eq!(p_arr.shape(), y_array.shape());
    // Tally each prediction value into the corresponding map
    Zip::from(&mut prediction_maps)
        .and(&p_arr)
        .for_each(|map, &val| *map.entry(val).or_insert(0) += 1);
}

// For each element, pick the value with the highest number of votes
y_array = prediction_maps.mapv(|map| map.into_iter().max_by_key(|&(_, v)| v).unwrap().0);

It picks out the prediction with the highest number of votes without the complexity of aggregate_predictions.
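
Note that HashMap-based voting requires the target element type to implement Eq + Hash, which holds for integer class labels but not for f64 targets, so this aggregation path is classification-only.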

Collaborator
This comment still applies, I believe.

@YuhanLiin mentioned this pull request Oct 22, 2022
@HridayM25 commented:
Hi! Could you please guide me as to what remains to be done here? Thank you!

@YuhanLiin (Collaborator) commented:

Merge/rebase with the latest master and address the open review comments. That's pretty much it.
